/**
 * cpair module
 * 
 * @author strlst <e11907086@student.tuwien.ac.at>
 * @date 2020-12-01
 * @brief implementations of the functions declared in the cpair header (cpair.h)
 */

#include "cpair.h"

#include <errno.h>
#include <float.h>
#include <math.h>
#include <stdlib.h>
#include <sys/wait.h>

void find_closest_pair(char **argv, char *line, point *p, size_t n) {
    /* create pipe */
    int pipefd[4][2];
    for (int i = 0; i < 4; i++)
        if (pipe(pipefd[i]))
            die("%s: pipe failed\n", argv[0], EXIT_FAILURE);

    /* make sure */
    fflush(stdout);

    /* split into two child processes and integrate results */
    pid_t pid1, pid2;
    /* only parent process is going to execute both forks */
    if ((pid1 = fork()))
         pid2 = fork();

    /* error handling */
    if (pid1 < (pid_t) 0 || pid2 < (pid_t) 0) {
        die("%s: cannot fork\n", argv[0], EXIT_FAILURE);
    /* child1 process */
    } else if (pid1 == (pid_t) 0) {
        /* close unnecessary file descriptors */
        if (close(pipefd[0][0]) < 0 ||
            close(pipefd[1][1]) < 0 ||
            close(pipefd[2][1]) < 0 ||
            close(pipefd[3][0]) < 0 ||
            close(pipefd[2][0]) < 0 ||
            close(pipefd[3][1]) < 0)
            die("%s: child 1 process: error closing unnecessary file descriptor\n", argv[0], EXIT_FAILURE);
        /* set up pipe file descriptors */
        if (dup2(pipefd[0][1], STDOUT_FILENO) < 0 ||
            dup2(pipefd[1][0], STDIN_FILENO) < 0)
            die("%s: child 1 process: error duplicating file descriptor\n", argv[0], EXIT_FAILURE);
        /* close old file descriptors */
        if (close(pipefd[0][1]) < 0 ||
            close(pipefd[1][0]) < 0)
            die("%s: child 1 process: error closing old file descriptor\n", argv[0], EXIT_FAILURE);
        /* launch recursive process */
        if (execlp(argv[0], argv[0], NULL) < 0)
            die("%s: child 1 process: error spawning program\n", argv[0], EXIT_FAILURE);
    /* child2 process */
    } else if (pid2 == (pid_t) 0) {
        if (close(pipefd[0][0]) < 0 ||
            close(pipefd[1][1]) < 0 ||
            close(pipefd[0][1]) < 0 ||
            close(pipefd[1][0]) < 0 ||
            close(pipefd[2][0]) < 0 ||
            close(pipefd[3][1]) < 0)
            die("%s: child 1 process: error closing unnecessary file descriptor\n", argv[0], EXIT_FAILURE);
        /* set up pipe file descriptors */
        if (dup2(pipefd[2][1], STDOUT_FILENO) < 0 ||
            dup2(pipefd[3][0], STDIN_FILENO) < 0)
            die("%s: child 2 process: error duplicating file descriptor\n", argv[0], EXIT_FAILURE);
        /* close old file descriptors */
        if (close(pipefd[2][1]) < 0 ||
            close(pipefd[3][0]) < 0)
            die("%s: child 1 process: error closing old file descriptor\n", argv[0], EXIT_FAILURE);
        /* launch recursive process */
        if (execlp(argv[0], argv[0], NULL) < 0)
            die("%s: child 1 process: error spawning program\n", argv[0], EXIT_FAILURE);
    /* parent process */
    } else {
        /* close unnecessary file descriptors */
        if (close(pipefd[0][1]) < 0 ||
            close(pipefd[1][0]) < 0 ||
            close(pipefd[2][1]) < 0 ||
            close(pipefd[3][0]) < 0)
            die("%s: parent process: failed closing file descriptors\n", argv[0], EXIT_FAILURE);

        /* create streams to interact with child processes */
        FILE *stream[4];
        stream[0] = fdopen(pipefd[0][0], "r");
        stream[1] = fdopen(pipefd[1][1], "w");
        stream[2] = fdopen(pipefd[2][0], "r");
        stream[3] = fdopen(pipefd[3][1], "w");
        /* check results of fdopen */
        for (int i = 0; i < 4; i++)
            if (stream[i] == NULL)
                die("%s: parent process: failed opening stream\n", argv[0], EXIT_FAILURE);

        /* calculate arithmetic mean x_m */
        /* assumes double precision */
        /* algorithm by Knuth, Art of Computer Programming */
        double x_m = 0;
        for (int i = 0; i < n; i++)
            x_m += (p[i].x - x_m) / (i + 1);
        double x_m2 = 0;
        for (int i = 0; i < n; i++)
            x_m2 += p[i].x;
        x_m2 /= n;

        /* keep track of how many elements are in each bin */
        size_t n_bin1 = 0, n_bin2 = 0, points_bin1 = 0, points_bin2 = 0;

        /* make life easier */
        bubblesort(p, n);

        /* iterate elements to bin them */
        /* dry run to count points for allocation */
        for (int i = 0; i < n; i++)
            /* greater means bin2 */
            /* we also need to take care to balance very similar values */
            /* otherwise bin1 would grab all values <= x_m */
            if (p[i].x > x_m || (abs(p[i].x - x_m) < 1e-6 && points_bin1 > points_bin2))
                ++points_bin2;
            /* less or equal means bin1 */
            else
                ++points_bin1;

        //fprintf(stderr, "lol %li, %li, %li\n", n, points_bin1, points_bin2);
        /* divide points into two bins using x_m */
        point p_bin1[points_bin1];
        point p_bin2[points_bin2];

        /* check if bins were allocated */
        if (p_bin1 == NULL ||
            p_bin2 == NULL) {
            die("%s: could not allocate points bin buffers: %s\n", argv[0], EXIT_FAILURE);
        }

        /* iterate elements to bin them */
        for (int i = 0; i < n; i++) {
            /* greater means bin2 */
            /* we also need to take care to balance very similar values */
            /* otherwise bin1 would grab all values <= x_m */
            if (p[i].x > x_m || (abs(p[i].x - x_m) < 1e-6 && n_bin1 > n_bin2)) {
                p_bin2[n_bin2].x = p[i].x;
                p_bin2[n_bin2].y = p[i].y;
                ++n_bin2;
            /* less or equal means bin1 */
            } else {
                p_bin1[n_bin1].x = p[i].x;
                p_bin1[n_bin1].y = p[i].y;
                ++n_bin1;
            }
        }

        /* DEBUG */
/*
        fprintf(stderr, "bin1:\n");
        for (int i = 0; i < n_bin1; i++)
            printf("%le %le\n", p_bin1[i].x, p_bin1[i].y);
        fprintf(stderr, "bin2:\n");
        for (int i = 0; i < n_bin2; i++)
            printf("%le %le\n", p_bin2[i].x, p_bin2[i].y);
        fprintf(stderr, "end\n");
*/

        /* write first bin to child process 1 */
        for (int i = 0; i < n_bin1; i++)
            fprintf(stream[1], "%f %f\n", p_bin1[i].x, p_bin1[i].y);
        fflush(stream[1]);
        /* write first bin to child process 2 */
        for (int i = 0; i < n_bin2; i++)
            fprintf(stream[3], "%f %f\n", p_bin2[i].x, p_bin2[i].y);
        fflush(stream[3]);

        /* close streams we are done with */
        if (fclose(stream[1]) < 0)
            die("%s: parent process: error closing stream writing to child process 1\n", argv[0], EXIT_FAILURE);
        if (fclose(stream[3]) < 0)
            die("%s: parent process: error closing stream writing to child process 2\n", argv[0], EXIT_FAILURE);

        /* save best contenders for shortest pair */
        pair p1, p2, p3;

        /* TODO: perhaps find a better way to do this? */
        /* read results from children with error handling */
        p1 = try_read(line, stream[0]);
        p2 = try_read(line, stream[2]);

        /* close remaining streams */
        if (fclose(stream[0]) < 0)
            die("%s: parent process: error closing stream reading from child process 1\n", argv[0], EXIT_FAILURE);
        if (fclose(stream[2]) < 0)
            die("%s: parent process: error closing stream reading from child process 2\n", argv[0], EXIT_FAILURE);

        /* wait for children to finish playing */
        int status1, status2;
        waitpid(pid1, &status1, 0);
        waitpid(pid2, &status2, 0);
        /* gracefully handle error */
        if (status1 < EXIT_SUCCESS || status2 < EXIT_SUCCESS)
            die("%s: a child process has not terminated with EXIT_SUCCESS: %s\n", argv[0], EXIT_FAILURE);

        /* initialize min dist to be max double value */
        double p3_dist = DBL_MAX;
        /* go through all point combinations */
        for (int i = 0; i < n_bin1; i++) {
            for (int j = 0; j < n_bin2; j++) {
                /* calculate distance */
                double dist = euclidean_distance(p_bin1[i], p_bin2[j]);
                /* set lowest pair so far */
                if (dist < p3_dist) {
                    p3_dist = dist;
                    /* save pair */
                    p3.a = p_bin1[i];
                    p3.b = p_bin2[j];
                }
            }
        }

        /* get the best pair and print the results */
        pair best_pair = get_best_contender(p1, p2, p3, p3_dist);
        printf("%le %le\n%le %le\n", best_pair.a.x, best_pair.a.y, best_pair.b.x, best_pair.b.y);

         /* to be safe */
        fflush(stdout);

        /* the name says it all */
        print_tree(n, p, n_bin1, p_bin1, n_bin2, p_bin2);
    }
}

/*
 * Report a fatal error on stderr and terminate the process with `exit_code`.
 * `message` is a printf format: its first %s receives `program`, and an
 * optional second %s receives the description of the current errno.
 */
void die(char* message, char* program, int exit_code) {
    const char *reason = strerror(errno);
    fprintf(stderr, message, program, reason);
    exit(exit_code);
}

/*
 * Euclidean (L2) distance between points a and b.
 * hypot() replaces sqrt(pow(dx, 2) + pow(dy, 2)) because it does not overflow
 * or underflow in the intermediate squares (e.g. coordinates around 1e200 or
 * 1e-200). It is symmetric in a and b, just like the formula it replaces.
 */
double euclidean_distance(point a, point b) {
    return hypot(a.x - b.x, a.y - b.y);
}

pair try_read(char *line_buffer, FILE *stream) {
    char *ret;
    pair p, p_inv;

    /* init */
    p.a.x     = p.a.y     = p.b.x     = p.b.y     = INVALID;
    p_inv.a.x = p_inv.a.y = p_inv.b.x = p_inv.b.y = INVALID;

    ret = fgets(line_buffer, BUF_SIZE + 1, stream);
    if (ret == NULL)
        return p_inv;
    if (sscanf(line_buffer, "%le %le", &(p.a.x), &(p.a.y)) < 2)
        return p_inv;
    ret = fgets(line_buffer, BUF_SIZE + 1, stream);
    if (ret == NULL)
        return p_inv;
    if (sscanf(line_buffer, "%le %le", &(p.b.x), &(p.b.y)) < 2)
        return p_inv;
    return p;
}

/* Return 1 if neither coordinate of p equals the INVALID sentinel, else 0. */
int is_valid(point p) {
    return p.x != INVALID && p.y != INVALID;
}

/*
 * Pick the closest pair among the children's answers and the cross-bin pair.
 *
 * The four endpoints of p1 and p2 are compared pairwise, skipping INVALID
 * points. p3 is not mixed into that list (duplicate points would be a
 * problem). Instead, it replaces the result only when both its endpoints are
 * valid and p3_dist is strictly smaller than the best candidate distance.
 * If nothing is valid, a pair with all coordinates 0 is returned.
 */
pair get_best_contender(pair p1, pair p2, pair p3, double p3_dist) {
    point candidates[4] = { p1.a, p1.b, p2.a, p2.b };

    pair best;
    best.a.x = best.a.y = best.b.x = best.b.y = 0.;
    double best_dist = DBL_MAX;

    /* each unordered pair once; distance is symmetric, so (j, i) adds nothing */
    for (int i = 0; i < 4; i++) {
        if (!is_valid(candidates[i]))
            continue;
        for (int j = i + 1; j < 4; j++) {
            if (!is_valid(candidates[j]))
                continue;
            double d = euclidean_distance(candidates[i], candidates[j]);
            if (d < best_dist) {
                best_dist = d;
                best.a = candidates[i];
                best.b = candidates[j];
            }
        }
    }

    int cross_pair_wins = is_valid(p3.a) && is_valid(p3.b) && p3_dist < best_dist;
    if (cross_pair_wins) {
        best.a = p3.a;
        best.b = p3.b;
    }
    return best;
}

/*
 * Sort p[0..n-1] in place by ascending x coordinate.
 *
 * Stable bubble sort: only strictly greater neighbours are swapped, so points
 * with equal x keep their relative order. A pass with no swaps ends the sort
 * early, so already-sorted input costs one O(n) pass. This replaces the old
 * "already sorted?" pre-check, which stopped one pair short (k < n - 2) and
 * wrongly skipped sorting inputs such as {2, 1} or {1, 3, 2}.
 */
void bubblesort(point *p, int n) {
    for (int end = n - 1; end > 0; end--) {
        int swapped = 0;
        for (int j = 0; j < end; j++) {
            if (p[j].x > p[j + 1].x) {
                /* swap whole points so every field travels together */
                point tmp = p[j];
                p[j] = p[j + 1];
                p[j + 1] = tmp;
                swapped = 1;
            }
        }
        if (!swapped)
            break;
    }
}

void print_tree(int n, point *p, int n_bin1, point *p_bin1, int n_bin2, point *p_bin2) {
    /* TODO: cleanup maybe? this is really messy */
    for (int i = n; i > 1; i /= 2)
        fprintf(stdout, "    ");
    fprintf(stdout, "CPAIR({");
    for (int i = 0; i < n - 1; i++)
        fprintf(stdout, "(%.2f, %.2f), ", p[i].x, p[i].y);
    fprintf(stdout, "(%.2f, %.2f)})\n", p[n - 1].x, p[n - 1].y);
    for (int i = n * 2; i > 1; i /= 2)
        fprintf(stdout, "    ");
    fprintf(stdout, "/");
    for (int i = 0; i < n; i++)
        /* approximate amount of spaces a point representation would use */
        fprintf(stdout, "            ");
    /* approximate amount of spaces the "CPAIR({})" would use */
    fprintf(stdout, "         \\\n");
    for (int i = n / 2; i > 1; i /= 2)
        fprintf(stdout, "    ");
    fprintf(stdout, "CPAIR({");
    for (int i = 0; i < n_bin1 - 1; i++)
        fprintf(stdout, "(%.2f, %.2f), ", p_bin1[i].x, p_bin1[i].y);
    fprintf(stdout, "(%.2f, %.2f)})", p_bin1[n_bin1 - 1].x, p_bin1[n_bin1 - 1].y);
    for (int i = (n * 2) / 2; i > 1; i /= 2)
        fprintf(stdout, "    ");
    fprintf(stdout, "CPAIR({");
    for (int i = 0; i < n_bin2 - 1; i++)
        fprintf(stdout, "(%.2f %.2f), ", p_bin2[i].x, p_bin2[i].y);
    fprintf(stdout, "(%.2f, %.2f)})\n", p_bin2[n_bin2 - 1].x, p_bin2[n_bin2 - 1].y);
}
